R"""
"""
#
import torch
from typing import Tuple, List
from .dyngraph import DynamicGraph
from ..model.model import Model
from ..model.mlp import MLP, SDGNN_MLP, GLNN_MLP
from ..model.activate import activatize
from .regression import loss, metrics
from collections import OrderedDict



class TrafficCross(DynamicGraph):
    R"""
    Traffic at crossing prediction.
    It is a temporal final node regression task.

    A temporal graph model (`tgnn`) produces per-node embeddings.
    When `notembedon` is empty, an MLP head (`mlp`) maps those embeddings to
    `target_feat_size` outputs per node; otherwise the raw `tgnn` output is
    returned.
    """
    # Gaussian noise added to node embeddings in `forward` (before the MLP
    # head).
    # Class-level defaults so the attributes exist even if `__init__` is
    # bypassed; `__init__` overrides them per instance.
    noise_mean: float = 0.0
    noise_std: float = 0.0

    def __init__(
        self,
        tgnn: Model, target_feat_size: int, embed_inside_size: int,
        /,
        *,
        activate: str, notembedon: List[int],
        noise_mean: float = 0.0, noise_std: float = 0.0,
    ) -> None:
        R"""
        Initialize the class.

        Args
        ----
        - tgnn
            Temporal graph model producing node embeddings.
        - target_feat_size
            Number of regression target features per node.
        - embed_inside_size
            Hidden size passed to the MLP heads.
        - activate
            Activation name passed to the MLP heads.
        - notembedon
            When empty, an MLP head is built on top of `tgnn`; otherwise it
            is forwarded to `tgnn.moveon` and no head is built.
        - noise_mean
            Mean of the Gaussian noise added to node embeddings in `forward`.
        - noise_std
            Standard deviation of that noise. With both defaults (0.0) the
            added noise is all zeros.
        """
        #
        DynamicGraph.__init__(self)

        #
        self.target_feat_size = target_feat_size
        self.noise_mean = noise_mean
        self.noise_std = noise_std

        #
        self.tgnn = tgnn
        self.notembedon = notembedon
        if not self.notembedon:
            # Regression head on top of the temporal model output.
            self.mlp = (
                MLP(
                    self.tgnn.feat_target_size, target_feat_size,
                    embed_inside_size,
                    activate=activate,
                )
            )
            self.activate = activatize(activate)
        else:
            #
            self.tgnn.moveon(self.notembedon)

        # NOTE: construction order of the submodules below is kept fixed
        # because it determines parameter initialization order and
        # state-dict key order.
        self.transformation_model = (
            SDGNN_MLP(
                self.tgnn.snn_node.hidden_size, self.tgnn.feat_target_size,
                embed_inside_size,
                activate=activate,
            )
        )

        self.glnn_rnn = (
            GLNN_MLP(
                self.tgnn.snn_node.hidden_size, self.tgnn.feat_target_size,
                embed_inside_size,
                activate=activate,
            )
        )

        self.new_mlp = (
            MLP(
                self.tgnn.feat_target_size, target_feat_size,
                embed_inside_size,
                activate=activate,
            )
        )

        self.training_stage = True

        print("=" * 10 + " " + "\x1b[37mConfiguration of\x1b[0m: \"mlp\"" + " " + "=" * 10)

    @staticmethod
    def _load_pretrained_state(module: torch.nn.Module, path: str, /) -> None:
        R"""
        Load the state dict stored at `path` into `module`.
        """
        # Load onto the CPU: `load_state_dict` copies the values into the
        # module's existing tensors (on whatever device they live), so this
        # also works for GPU-saved checkpoints on CPU-only machines.
        # NOTE(review): `torch.load` unpickles the file; only load trusted
        # checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        # Pass the loaded mapping as-is so any `_metadata` attached by
        # `Module.state_dict()` (module version info) is preserved.
        module.load_state_dict(state_dict)

    def continue_train_with_pretrain_mlp(self, path: str, /) -> None:
        R"""
        Use pretrained model.

        Overwrite the parameters of the MLP head (`mlp`) with the state dict
        saved at `path`. An empty `path` means "no pretrained model" and does
        nothing.
        """
        if not path:
            return
        print("=" * 10 + " " + "\x1b[37mWill use pretrained low level mlp model to continue train with \x1b[0m: \"mlp\"" + " " + "=" * 10)
        self._load_pretrained_state(self.mlp, path)

    def continue_train_with_pretrain_transformation_model(self, path: str, /) -> None:
        R"""
        Use pretrained model.

        Overwrite the parameters of `transformation_model` with the state
        dict saved at `path`. An empty `path` means "no pretrained model" and
        does nothing.
        """
        if not path:
            return
        print("=" * 10 + " " + "\x1b[37mWill use pretrained transformation model to continue train with \x1b[0m: \"mlp\"" + " " + "=" * 10)
        self._load_pretrained_state(self.transformation_model, path)

    def reset(self, rng: torch.Generator, /) -> int:
        R"""
        Reset model parameters by given random number generator.

        Returns the sum of the values reported by the submodule `reset`
        calls: `tgnn`, plus `mlp` when the MLP head exists.
        """
        #
        resetted = 0
        resetted = resetted + self.tgnn.reset(rng)
        if not self.notembedon:
            #
            resetted = resetted + self.mlp.reset(rng)
        # NOTE(review): `transformation_model`, `glnn_rnn` and `new_mlp` are
        # not reset here; confirm that this is intended.
        return resetted

    def forward(
        self,
        edge_tuples: torch.Tensor, edge_feats: torch.Tensor,
        edge_labels: torch.Tensor, edge_ranges: torch.Tensor,
        edge_times: torch.Tensor, node_feats: torch.Tensor,
        node_labels: torch.Tensor, node_times: torch.Tensor,
        node_masks: torch.Tensor,
        /,
    ) -> List[torch.Tensor]:
        R"""
        Forward.

        Run `tgnn` on the graph inputs.
        When the MLP head exists, add Gaussian noise
        (`noise_mean`, `noise_std`) to the node embeddings and apply the
        head. Returns a one-element list holding the node outputs.

        Raises `NotImplementedError` when edge or node labels are not
        0-dimensional placeholders (label input embeddings are unsupported).
        """
        #
        node_embeds: torch.Tensor

        # We do not have label embedding layers.
        if edge_labels.ndim > 0 or node_labels.ndim > 0:
            # UNEXPECT:
            # Current tasks does not assume any edge or node label input
            # embeddings.
            raise NotImplementedError(
                "Edge or node label input is not supported.",
            )

        node_embeds = (
            self.tgnn.forward(
                edge_tuples, edge_feats, edge_ranges, edge_times, node_feats,
                node_times, node_masks,
            )
        )
        if not self.notembedon:
            # NOTE: noise is sampled on every call (even when it is all
            # zeros) from the default CPU generator and then moved to the
            # embedding device; skipping or relocating the draw would change
            # the downstream random stream.
            noise = torch.randn(node_embeds.size()) * self.noise_std + self.noise_mean
            node_embeds = node_embeds + noise.to(node_embeds.device)
            node_embeds = self.mlp(node_embeds)
        return [node_embeds]

    def _select_existing_nodes(
        self,
        node_output_feats: torch.Tensor, node_target_feats: torch.Tensor,
        node_masks: torch.Tensor,
        /,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        R"""
        Reshape outputs and targets to (rows, target_feat_size) and keep only
        the rows whose mask is positive.
        """
        node_exists = node_masks > 0
        selected_outputs = (
            torch.reshape(
                node_output_feats,
                (len(node_output_feats), self.target_feat_size),
            )[node_exists]
        )
        selected_targets = (
            torch.reshape(
                node_target_feats,
                (len(node_target_feats), self.target_feat_size),
            )[node_exists]
        )
        return (selected_outputs, selected_targets)

    def loss(self, /, *ARGS) -> torch.Tensor:
        R"""
        Loss function.

        `ARGS` is (node_output_feats, node_target_feats, node_labels,
        node_masks); node labels are ignored and only nodes with a positive
        mask contribute.
        """
        #
        node_output_feats: torch.Tensor
        node_target_feats: torch.Tensor
        node_masks: torch.Tensor

        # Output only has node feature-like data.
        # Target node label data are not useful in this task.
        (node_output_feats, node_target_feats, _, node_masks) = ARGS
        (node_output_feats, node_target_feats) = (
            self._select_existing_nodes(
                node_output_feats, node_target_feats, node_masks,
            )
        )
        return loss(node_output_feats, node_target_feats)

    def metrics(self, /, *ARGS) -> List[Tuple[int, float]]:
        R"""
        Evaluation metrics.

        `ARGS` is (node_output_feats, node_target_feats, node_labels,
        node_masks); node labels are ignored and only nodes with a positive
        mask are evaluated.
        """
        #
        node_output_feats: torch.Tensor
        node_target_feats: torch.Tensor
        node_masks: torch.Tensor

        # Output only has node feature-like data.
        # Target node label data are not useful in this task.
        (node_output_feats, node_target_feats, _, node_masks) = ARGS
        (node_output_feats, node_target_feats) = (
            self._select_existing_nodes(
                node_output_feats, node_target_feats, node_masks,
            )
        )
        return metrics(node_output_feats, node_target_feats)